Disease classification on PlantVillage
In this chapter, we will design a CNN to perform the plant disease classification task on the PlantVillage dataset. This includes several steps, which are outlined below:
- Explore and preprocess the PlantVillage dataset
- Design an isotropic CNN architecture
- Train the CNN on the PlantVillage dataset
- Analyze the accuracy of the CNN model through the lens of a hierarchical confusion matrix
The PlantVillage dataset
The PlantVillage dataset is a collection of 54,305 images of 14 different plant species, spanning 38 classes, of which 12 are healthy and 26 are diseased.
The dataset was created by the Penn State College of Agricultural Sciences and the International Institute of Tropical Agriculture as a resource for research and development of computer vision-based plant disease detection systems. The images in the dataset were collected from various sources, including research institutions and citizen scientists, and represent a wide variety of plant species and disease types.
The plants include fruits such as apple, blueberry, cherry, grape, orange, peach, raspberry, squash, and strawberry; crops such as corn and soybean; and vegetables such as bell pepper, potato, and tomato. Each plant is either healthy or affected by a disease such as scab, rot, or rust.
import pandas as pd

df = pd.read_csv('data/cls_count.csv')  # per-class image counts
df[['Plant', 'Disease', 'Count']]
The number of images differs from class to class. Such a skewed distribution of image counts in a dataset is called imbalanced. An imbalanced dataset is more difficult to train on than a balanced one.
xticks = range(38)
ax = df.plot.bar(
    x='Disease', y='Count',
    title='Imbalanced distribution of the counts of images',
    xlabel='Classes', xticks=xticks,
    figsize=(10, 5))
legend = ax.legend(loc=2)
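One common way to quantify this imbalance is the ratio between the largest and smallest class counts. The snippet below sketches that computation on made-up counts (the class names and counts are illustrative placeholders, not values read from `cls_count.csv`):

```python
import pandas as pd

# Illustrative placeholder counts (not read from data/cls_count.csv):
# one over-represented and one under-represented class.
counts = pd.DataFrame({
    'Disease': ['Class A', 'Class B'],
    'Count': [5500, 150],
})
imbalance_ratio = counts['Count'].max() / counts['Count'].min()
print(f"Imbalance ratio: {imbalance_ratio:.1f}")
```

A ratio far above 1 signals that techniques such as class weighting or resampling may be needed during training.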
Next, let us show 38 images, one for each category.
import os

root_dir = "data/plantvillage/"
samples = []
classes = sorted(os.listdir(root_dir))
for cls in classes:
    cls_path = os.path.join(root_dir, cls)
    if os.path.isdir(cls_path):
        # take the first image of each class as its sample
        img_name = sorted(os.listdir(cls_path))[0]
        samples.append(os.path.join(cls_path, img_name))
from PIL import Image
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

fig = plt.figure(figsize=(12., 20.))
grid = ImageGrid(fig, 111,            # similar to subplot(111)
                 nrows_ncols=(6, 7),  # creates a 6x7 grid of axes
                 axes_pad=0.1)        # pad between axes in inches
for ax in grid:                       # hide the unused trailing axes
    ax.axis("off")
for ax, img_path in zip(grid, samples):
    img = Image.open(img_path)
    ax.axis("off")
    ax.imshow(img)
A lightweight isotropic CNN architecture
CNNs can be divided into two types: isotropic and pyramidal. Isotropic CNNs keep the same feature-map size and shape across all layers of the network, while pyramidal CNNs use layers with varying sizes and shapes. The difference between them is illustrated in the figure below.
isotropic_vs_pyramidical = Image.open("data/isotropic_vs_pyramidical.PNG")
plt.figure(figsize=(10, 5))
plt.imshow(isotropic_vs_pyramidical)
plt.axis("off")
plt.show()
Isotropic CNNs emerged partly inspired by the state-of-the-art attention-based transformer architectures in computer vision, which are themselves isotropic. Recent research finds that, compared to pyramidal architectures, isotropic architectures can match or even reach state-of-the-art performance with much lighter layers.
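To make the distinction concrete, the sketch below (a toy PyTorch illustration, not the FoldNet architecture) stacks convolutions in the two styles: the isotropic stack keeps the channel count and spatial resolution fixed throughout, while the pyramidal stack halves the resolution and doubles the channels at each stage.

```python
import torch
import torch.nn as nn

def isotropic_stack(channels=64, depth=4):
    # every block keeps the same channel count and spatial resolution
    layers = [nn.Conv2d(3, channels, 3, padding=1)]
    for _ in range(depth):
        layers += [nn.Conv2d(channels, channels, 3, padding=1), nn.ReLU()]
    return nn.Sequential(*layers)

def pyramidal_stack(channels=16, depth=4):
    # each stage halves the resolution (stride 2) and doubles the channels
    layers = [nn.Conv2d(3, channels, 3, padding=1)]
    for _ in range(depth):
        layers += [nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1),
                   nn.ReLU()]
        channels *= 2
    return nn.Sequential(*layers)

x = torch.randn(1, 3, 64, 64)
print(isotropic_stack()(x).shape)  # resolution preserved: (1, 64, 64, 64)
print(pyramidal_stack()(x).shape)  # resolution reduced:   (1, 256, 4, 4)
```

The channel width and depth here are arbitrary choices for the illustration; the point is only how the feature-map shape evolves through the two kinds of stack.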
We propose a lightweight isotropic CNN, FoldNet, which achieves 99.84% accuracy on the disease classification task on the PlantVillage dataset.
from IPython.display import Image, display  # note: this Image shadows PIL's Image

display(Image('data/foldnet_arch.png', width="80%"))
Hierarchical Confusion Matrix of PlantVillage
A confusion matrix is a visualization tool in machine learning that helps evaluate the performance of a classification model. It is a tabular layout that compares predicted class labels against actual class labels over all data instances. The rows of the matrix represent the actual classes, while the columns represent the predicted classes. By analyzing the confusion matrix, we can determine how well the model distinguishes between classes, as well as which classes are most often confused with one another. Popular performance metrics, such as accuracy, precision, recall, and F1 score, can be derived from the confusion matrix.
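As a small illustration (toy labels, not PlantVillage predictions), the snippet below builds a confusion matrix with scikit-learn and derives accuracy from its diagonal:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy labels for a 3-class problem: rows of cm are actual classes,
# columns are predicted classes, and the diagonal holds correct predictions.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 2]
cm = confusion_matrix(y_true, y_pred)
print(cm)

# accuracy = correct predictions / all predictions
accuracy = np.trace(cm) / cm.sum()
print(accuracy)
```

Precision and recall per class can be read off the same matrix as the diagonal entry divided by the column sum and the row sum, respectively.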
The PlantVillage dataset has a tree-like hierarchical structure with three levels. The root node is the overall category, plant. The first level contains the 14 plant species. The second level is the healthy or diseased status of the particular plant. Thus we use a hierarchical confusion matrix to capture the hierarchical structure of the dataset.
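One way to exploit this hierarchy is to collapse the fine-grained confusion matrix to the species level, so that cross-species errors can be separated from within-species (disease-status) errors. The sketch below uses a made-up 4-class matrix (not FoldNet's results), where classes 0 and 1 belong to one species and classes 2 and 3 to another:

```python
import numpy as np

# Hypothetical 4-class confusion matrix (rows = actual, cols = predicted).
cm = np.array([[10, 2, 0, 0],
               [1, 9, 0, 1],
               [0, 0, 12, 3],
               [0, 0, 2, 8]])
species_of = [0, 0, 1, 1]  # class index -> species index

# Sum each fine-grained cell into its (actual species, predicted species) cell.
species_cm = np.zeros((2, 2), dtype=int)
for i, si in enumerate(species_of):
    for j, sj in enumerate(species_of):
        species_cm[si, sj] += cm[i, j]
print(species_cm)
```

Off-diagonal entries of the collapsed matrix count cross-species mistakes, while the diagonal blocks of the original matrix contain the within-species disease confusions.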
The following is an interactive widget to visualize the hierarchical confusion matrix of the FoldNet model when evaluating on the 10,861 testing images of the PlantVillage dataset.
The FoldNet model achieves 99.84% accuracy, with only 17 images classified incorrectly. After examining these 17 images, we find three interesting points worth noting:
First, compared to misclassifications within the same species, misclassifications across species are very rare. Only 5 images are incorrectly identified as a different plant species, while the other 12 images are identified correctly as to their species, though incorrectly as to their disease status. This reflects the robustness of the FoldNet model, which can correctly predict the species of an image even when its prediction of the image's disease status is wrong.
Second, the 12 images misclassified within the same species belong to only two species, corn and tomato, rather than being uniformly distributed across all 14 species: 4 of the 12 belong to corn, and the other 8 belong to tomato. This reflects the complexity of the corn and tomato images.
Third, among the 17 misclassified images, several have incorrect ground-truth labels or were captured under extreme conditions. For example, the first 'Cherry Healthy' image actually shows only field background, and two "Tomato Late Blight" images have a very small foreground against a very large background.
display(Image('data/falsely_predicted_across_species.png', width="100%"))
display(Image('data/falsely_predicted_inner_species.png', width="100%"))